
✨ Add support for SQLAlchemy polymorphic models #1226


Open: wants to merge 27 commits into base: main

Conversation


@PaleNeutron PaleNeutron commented Nov 26, 2024

Introduce support for SQLAlchemy polymorphic models by adjusting field defaults and handling inheritance correctly in the SQLModel metaclass. Add tests verifying polymorphic joined-table and single-table inheritance. Refer to #36.

@PaleNeutron PaleNeutron marked this pull request as ready for review November 26, 2024 03:26
@PaleNeutron PaleNeutron changed the title Support SQLAlchemy polymorphic models Add support for SQLAlchemy polymorphic models Nov 26, 2024
@PaleNeutron PaleNeutron changed the title Add support for SQLAlchemy polymorphic models [feature] Add support for SQLAlchemy polymorphic models Nov 26, 2024
@mmx86

mmx86 commented Dec 2, 2024

@tiangolo
Could you please comment on whether this request has a good chance of being merged?
My team and I are under time constraints and are trying to decide whether to commit to this feature now.

@ndeybach

We are also exploring using SQLModel in our products. This would be a real quality-of-life improvement for how we are building our stack.

@tiangolo do you have a timeline for when this could be merged, or what still needs to be done?

@guhur

guhur commented Jan 3, 2025

Thanks a lot for this PR! We would love to add this feature in our codebase.

Unfortunately, we could not use this PR along with a custom type.

@PaleNeutron would you mind checking this MRE?

(1) The code works fine if you comment out DarkHero or comment out my_model.

(2) However, it fails if both are present in the module!

Code

import json
import typing as t

from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel, TypeAdapter

# Warning: we import a deprecated class from the `pydantic` package
# See: https://github.com/pydantic/pydantic/issues/6381
from pydantic._internal._model_construction import ModelMetaclass  # noqa: PLC2701
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.orm import mapped_column
from sqlalchemy.sql.type_api import _BindProcessorType, _ResultProcessorType
from sqlmodel import (
    JSON,
    Column,
    Field,
    Session,
    SQLModel,
    TypeDecorator,
    create_engine,
    select,
)


def pydantic_column_type(  # noqa: C901
    pydantic_type: type[t.Any],
) -> type[TypeDecorator]:
    """
    See details here:
    https://github.com/tiangolo/sqlmodel/issues/63#issuecomment-1081555082
    """
    T = t.TypeVar("T")

    class PydanticJSONType(TypeDecorator, t.Generic[T]):
        impl = JSON()
        cache_ok = False

        def __init__(
            self,
            json_encoder: t.Any = json,
        ):
            self.json_encoder = json_encoder
            super().__init__()

        def bind_processor(self, dialect: Dialect) -> _BindProcessorType[T] | None:
            impl_processor = self.impl.bind_processor(dialect)
            if impl_processor:

                def process(value: T | None) -> T | None:
                    if value is not None:
                        if isinstance(pydantic_type, ModelMetaclass):
                            value_to_dump = pydantic_type.model_validate(value)
                        else:
                            value_to_dump = value
                        value = jsonable_encoder(value_to_dump)
                    return impl_processor(value)

            else:

                def process(value: T | None) -> T | None:
                    if isinstance(pydantic_type, ModelMetaclass):
                        value_to_dump = pydantic_type.model_validate(value)
                    else:
                        value_to_dump = value
                    return jsonable_encoder(value_to_dump)

            return process

        def result_processor(
            self,
            dialect: Dialect,
            coltype: object,
        ) -> _ResultProcessorType[T] | None:
            impl_processor = self.impl.result_processor(dialect, coltype)
            if impl_processor:

                def process(value: T) -> T | None:
                    value = impl_processor(value)
                    if value is None:
                        return None

                    if isinstance(value, str):
                        value = json.loads(value)

                    return TypeAdapter(pydantic_type).validate_python(value)

            else:

                def process(value: T) -> T | None:
                    if value is None:
                        return None

                    if isinstance(value, str):
                        value = json.loads(value)

                    return TypeAdapter(pydantic_type).validate_python(value)

            return process

        def compare_values(self, x: t.Any, y: t.Any) -> bool:
            return x == y

    return PydanticJSONType


class MyModel(BaseModel):
    name: str | None = None


class ComplexModel(SQLModel, table=True):
    id: t.Annotated[
        int | None,
        Field(
            default=None,
            primary_key=True,
        ),
    ] = None
    my_model: t.Annotated[
        MyModel | None,
        Field(
            sa_column=Column(pydantic_column_type(MyModel)),
        ),
    ] = None


class Hero(SQLModel, table=True):
    __tablename__ = "hero"
    id: int | None = Field(default=None, primary_key=True)
    hero_type: str = Field(default="hero")

    __mapper_args__ = {
        "polymorphic_on": "hero_type",
        "polymorphic_identity": "hero",
    }


class DarkHero(Hero):
    dark_power: str = Field(
        default="dark",
        sa_column=mapped_column(
            nullable=False, use_existing_column=True, default="dark"
        ),
    )

    __mapper_args__ = {
        "polymorphic_identity": "dark",
    }


engine = create_engine("sqlite:///:memory:", echo=True)
SQLModel.metadata.create_all(engine)
with Session(engine) as db:
    hero = Hero()
    db.add(hero)
    dark_hero = DarkHero(dark_power="pokey")
    db.add(dark_hero)
    db.commit()
    statement = select(DarkHero)
    result = db.exec(statement).all()
assert len(result) == 1
assert isinstance(result[0].dark_power, str)

Corresponding error output

python test.py
Traceback (most recent call last):
  File "/Users/guhur/src/argile-lib-python/test.py", line 101, in <module>
    class DarkHero(Hero):
  File "/Users/guhur/Library/Caches/pypoetry/virtualenvs/argile-lib-python-RxGRaJe1-py3.11/lib/python3.11/site-packages/sqlmodel/main.py", line 542, in __new__
    new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/guhur/Library/Caches/pypoetry/virtualenvs/argile-lib-python-RxGRaJe1-py3.11/lib/python3.11/site-packages/pydantic/_internal/_model_construction.py", line 202, in __new__
    complete_model_class(
  File "/Users/guhur/Library/Caches/pypoetry/virtualenvs/argile-lib-python-RxGRaJe1-py3.11/lib/python3.11/site-packages/pydantic/_internal/_model_construction.py", line 572, in complete_model_class
    generate_pydantic_signature(init=cls.__init__, fields=cls.model_fields, config_wrapper=config_wrapper),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/guhur/Library/Caches/pypoetry/virtualenvs/argile-lib-python-RxGRaJe1-py3.11/lib/python3.11/site-packages/pydantic/_internal/_signature.py", line 159, in generate_pydantic_signature
    merged_params = _generate_signature_parameters(init, fields, config_wrapper)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/guhur/Library/Caches/pypoetry/virtualenvs/argile-lib-python-RxGRaJe1-py3.11/lib/python3.11/site-packages/pydantic/_internal/_signature.py", line 115, in _generate_signature_parameters
    kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)}
                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/guhur/Library/Caches/pypoetry/virtualenvs/argile-lib-python-RxGRaJe1-py3.11/lib/python3.11/site-packages/pydantic/fields.py", line 546, in get_default
    return _utils.smart_deepcopy(self.default)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/guhur/Library/Caches/pypoetry/virtualenvs/argile-lib-python-RxGRaJe1-py3.11/lib/python3.11/site-packages/pydantic/_internal/_utils.py", line 318, in smart_deepcopy
    return deepcopy(obj)  # slowest way when we actually might need a deepcopy
           ^^^^^^^^^^^^^
  [... repeated deepcopy / _reconstruct / _deepcopy_dict frames elided ...]
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 161, in deepcopy
    rv = reductor(4)
         ^^^^^^^^^^^
TypeError: cannot pickle 'module' object

@adsharma

This could help with a different kind of polymorphism. Details here.

Specifically:

@fquery.sqlmodel.model()
@dataclass
class Hero:
  ...

This creates two classes: Hero (a dataclass) and HeroSQLModel (an SQLModel).

Using polymorphism, we could allow the caller to return Hero (dataclass, cheaper when static typing is good enough) or HeroSQLModel if runtime validation is needed.
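The dual-class idea above can be sketched with nothing but the stdlib. The decorator below is hypothetical and only loosely mirrors the fquery example (the real fquery.sqlmodel.model decorator is not reproduced here); it shows the shape of the pattern, where the cheap dataclass stays the default and a validating companion class is generated alongside it:

```python
from dataclasses import dataclass, fields


def model(cls=None):
    """Hypothetical sketch: given a dataclass, attach a companion class
    (named <Name>SQLModel) that type-checks its fields at construction."""

    def wrap(dc):
        dc_fields = fields(dc)

        class Validated:
            def __init__(self, **kwargs):
                for f in dc_fields:
                    value = kwargs[f.name]
                    # Runtime check: the "validation" the plain dataclass skips.
                    if not isinstance(value, f.type):
                        raise TypeError(f"{f.name} must be {f.type.__name__}")
                    setattr(self, f.name, value)

        Validated.__name__ = dc.__name__ + "SQLModel"
        dc.Validated = Validated
        return dc

    return wrap if cls is None else wrap(cls)


@model()
@dataclass
class Hero:
    name: str
    secret_name: str
```

Callers that trust static typing construct `Hero(...)` directly; callers that need runtime validation use `Hero.Validated(...)` and get a TypeError on bad input.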

@PaleNeutron
Author

@guhur, good test. I found a bug through it.

@PaleNeutron PaleNeutron changed the title [feature] Add support for SQLAlchemy polymorphic models ✨[feature] Add support for SQLAlchemy polymorphic models Feb 14, 2025
@svlandeg svlandeg added the feature New feature or request label Feb 20, 2025
@svlandeg svlandeg changed the title ✨[feature] Add support for SQLAlchemy polymorphic models ✨ Add support for SQLAlchemy polymorphic models Feb 20, 2025
@ahmdatef

Can't wait for this PR to get merged. Thank you all for your efforts!

@svlandeg
Member

Thanks for this contribution @PaleNeutron!

We're currently going through the backlog of PRs and may need some time to catch up. We'll get back to you once someone in the team has been able to review this in detail 🙏

@AlePiccin

Any updates on when this feature is going to be available?

hsiung-bf pushed a commit to hsiung-bf/sqlmodel that referenced this pull request Apr 2, 2025
@KunxiSun

KunxiSun commented Apr 3, 2025

Hi @tiangolo,
Any chance we could get this PR merged soon? We would love to see it go live!

@dolfandringa

Hi @tiangolo, any chance we could get this PR merged soon? We would love to see it go live!

Also @svlandeg

@svlandeg
Member

Please don't continuously ping maintainers directly. As I said earlier on this PR, we have quite a backlog of PRs and we're managing it as well as we can. Having to respond to these pings actually takes more time and certainly doesn't speed up the overall process of maintaining our open-source repos.

@jensrischbieth

jensrischbieth commented May 15, 2025

@PaleNeutron, it seems this PR still doesn't allow polymorphism with other classes as fields, in particular the Relationship class. Consider the following example:

from sqlmodel import SQLModel, Field, Relationship

class Tool(SQLModel, table=True):
    __tablename__ = 'tool_table'

    id: int = Field(primary_key=True)

    name: str

class Person(SQLModel, table=True):
    __tablename__ = 'person_table'

    id: int = Field(primary_key=True)

    discriminator: str
    name: str

    tool_id: int = Field(foreign_key='tool_table.id')
    tool: Tool = Relationship()

    __mapper_args__ = {
        'polymorphic_on': 'discriminator',
        'polymorphic_identity': 'person_table',
    }


class Worker(Person):
    __mapper_args__ = {
        'polymorphic_identity': 'worker',
    }
pydantic.errors.PydanticSchemaGenerationError: Unable to generate pydantic-core schema for sqlalchemy.orm.base.Mapped[__main__.Tool]. Set `arbitrary_types_allowed=True` in the model_config to ignore this error or implement `__get_pydantic_core_schema__` on your type to fully support it.

If you got this error by calling handler(<some type>) within `__get_pydantic_core_schema__` then you likely need to call `handler.generate_schema(<some type>)` since we do not call `__get_pydantic_core_schema__` on `<some type>` otherwise to avoid infinite recursion.

For further information visit https://errors.pydantic.dev/2.11/u/schema-for-unknown-type

Thanks a lot for your contribution!

@PaleNeutron
Author

@jensrischbieth, good test. I forgot to deal with relationships in the parent class. I'll try to patch it, but I'm not sure it can be done, since SQLAlchemy's polymorphic relationships involve a lot of black magic.

@PaleNeutron
Author

@jensrischbieth, fixed.

@barrynorman

Nice work. When is this planned to be released?

@tsuga

tsuga commented Jun 30, 2025

Any update for this PR to be merged?

@0x003e

0x003e commented Jul 7, 2025

Hello, could someone please approve this merge request?
Thank you.

@jensrischbieth

jensrischbieth commented Jul 8, 2025

@PaleNeutron, I have found another potential issue.

Consider the following linked list:

from typing import Optional
from sqlmodel import Field, Relationship, SQLModel


class BaseNode(SQLModel, table=True):
    __tablename__ = 'node_table'
    
    id: str = Field(primary_key=True)
    node_type: str
    
    # Self-referential relationship - this causes the issue
    next_id: Optional[str] = Field(default=None, foreign_key='node_table.id')
    next: Optional['BaseNode'] = Relationship(
        sa_relationship_kwargs={
            'remote_side': '[BaseNode.id]',
            'uselist': False
        }
    )
    
    __mapper_args__ = {
        'polymorphic_on': 'node_type',
        'polymorphic_identity': 'base',
    }


class EmailNode(BaseNode):
    __mapper_args__ = {
        'polymorphic_identity': 'email',
    }


# Create two nodes
node1 = EmailNode(id="1", node_type="email")
node2 = EmailNode(id="2", node_type="email")

try:
    node1.next = node2  # This fails
except AttributeError as e:
    print(e)

# This works because it's just a regular field
try:
    node1.next_id = "2"  # This works
except Exception as e:
    print(e)
'next' is a ClassVar of `EmailNode` and cannot be set on an instance. If you want to set a value on the class, use `EmailNode.next = value`.

Thanks again for your contributions! Hopefully they can be merged soon!

@PaleNeutron
Author

@jensrischbieth Confirmed, working on it.

@budroco

budroco commented Aug 13, 2025

Thanks for this PR!

I constructed a simple example, in case anybody wants something to quickly evaluate.

The initial result looks quite promising to me and may mean we won't have to migrate away from SQLModel after all.

Unlike test_polymorphic_model.py, I swapped mapped_column for Column to get rid of type errors, and just ignored the error on __tablename__ because it's small and doesn't propagate. Now the entire script is free of type errors (at least in my IDE)!

Let's see how merging this PR goes; it would alleviate a lot of pain for a lot of people.

# This is a uv script. Run with `uv run` and dependencies will be fetched on the fly.
#
# /// script
# requires-python = ">=3.9"
# dependencies = [
#   "sqlmodel @ git+https://github.com/PaleNeutron/sqlmodel@sqlalchemy_polymorphic_support"
# ]
# ///

# Adapted from https://github.com/PaleNeutron/sqlmodel/blob/64f774fb3b5b66ef8d55aab0b26a7733146e60a8/tests/test_polymorphic_model.py

from typing import Optional

from sqlalchemy import Column, Integer
from sqlmodel import Field, Session, SQLModel, create_engine, select


class Animal(SQLModel, table=True):
    __tablename__ = "animal"  # type: ignore

    id: Optional[int] = Field(default=None, primary_key=True)
    name: str
    type: str = Field(default="animal")

    __mapper_args__ = {
        "polymorphic_on": "type",
        "polymorphic_identity": "animal",
    }


class Cat(Animal):
    meow_cuteness: int = Field(sa_column=Column(Integer, nullable=True), default=None)
    __mapper_args__ = {"polymorphic_identity": "cat"}


class Dog(Animal):
    bark_loudness: int = Field(sa_column=Column(Integer, nullable=True), default=None)
    __mapper_args__ = {"polymorphic_identity": "dog"}


if __name__ == "__main__":
    # Create database and session
    engine = create_engine("sqlite:///:memory:", echo=False)
    SQLModel.metadata.create_all(engine)

    with Session(engine) as db:
        db.add_all(
            [
                Animal(name="Generic Animal"),
                Cat(name="Whiskers", meow_cuteness=10),
                Dog(name="Rocky", bark_loudness=8),
            ]
        )
        db.commit()

        animals = db.exec(select(Animal)).all()
        print("All animals:", animals)

        cats = db.exec(select(Cat)).all()
        print("All cats:", cats)

        dogs = db.exec(select(Dog)).all()
        print("All dogs:", dogs)

Result:

All animals: [Animal(type='animal', name='Generic Animal', id=1), Cat(type='cat', name='Whiskers', id=2), Dog(type='dog', name='Rocky', id=3)]
All cats: [Cat(type='cat', name='Whiskers', id=2, meow_cuteness=10)]
All dogs: [Dog(type='dog', name='Rocky', id=3, bark_loudness=8)]

@ssoimada

Nice work, please release this!

@piewared

Can this be merged in soon? We don't want to have to abandon SQLModel, but polymorphism is really important for us.

@a0s

a0s commented Aug 21, 2025

MissingGreenlet Error with Polymorphic Models and Lazy Loading

Problem Description

When accessing lazy-loaded attributes on polymorphic SQLAlchemy models in an async context, you get a MissingGreenlet: greenlet_spawn has not been called; can't call await_only() here error.

Minimal Example

from sqlalchemy import Column, String, ForeignKey, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import SQLModel, Field

# Base polymorphic model
class BaseConnection(SQLModel, table=True):
    __tablename__ = "connections"
    __mapper_args__ = {"polymorphic_on": "type"}

    id: int = Field(primary_key=True)
    type: str = Field(sa_column=Column(String(50)))

# Derived polymorphic model with FK
class WireGuardConnection(BaseConnection, table=False):
    __mapper_args__ = {"polymorphic_identity": "wireguard"}
    peer_id: int = Field(sa_column=Column(ForeignKey("peers.id"), nullable=True))

async def delete_connection(session: AsyncSession, connection_id: int):
    # Initial query works fine
    stmt = select(BaseConnection).where(BaseConnection.id == connection_id)
    result = await session.execute(stmt)
    connection = result.scalar_one_or_none()

    # This triggers MissingGreenlet error due to lazy loading
    if hasattr(connection, "peer_id") and connection.peer_id:  # ❌ Error here
        print(f"Peer ID: {connection.peer_id}")

Error Details

sqlalchemy.exc.MissingGreenlet: greenlet_spawn has not been called; can't call await_only() here.
Was IO attempted in an unexpected place?

Root Cause

When accessing connection.peer_id, SQLAlchemy tries to perform lazy loading to fetch the attribute, but this happens outside the proper greenlet context.

This issue specifically occurs with polymorphic inheritance when trying to access subclass attributes that weren't loaded in the initial query.

Solution 1: Use session.refresh()

Use session.refresh() to eager load all attributes:

async def delete_connection(session: AsyncSession, connection_id: int):
    stmt = select(BaseConnection).where(BaseConnection.id == connection_id)
    result = await session.execute(stmt)
    connection = result.scalar_one_or_none()

    # Refresh to load all attributes properly
    await session.refresh(connection)  # ✅ Fix

    # Now safe to access attributes
    if hasattr(connection, "peer_id") and connection.peer_id:
        print(f"Peer ID: {connection.peer_id}")

Solution 2: Use joinedload or selectinload

Load relationships eagerly in the initial query:

from sqlalchemy.orm import selectinload

async def delete_connection(session: AsyncSession, connection_id: int):
    stmt = (
        select(BaseConnection)
        .options(selectinload("*"))  # Load all relationships
        .where(BaseConnection.id == connection_id)
    )
    result = await session.execute(stmt)
    connection = result.scalar_one_or_none()

    # Safe to access attributes - already loaded
    if hasattr(connection, "peer_id") and connection.peer_id:
        print(f"Peer ID: {connection.peer_id}")

Solution 3: Type-safe approach

Check the type before accessing subclass attributes:

async def delete_connection(session: AsyncSession, connection_id: int):
    stmt = select(BaseConnection).where(BaseConnection.id == connection_id)
    result = await session.execute(stmt)
    connection = result.scalar_one_or_none()

    # Refresh to be safe
    await session.refresh(connection)

    # Type-safe check
    if connection.type == "wireguard":
        wireguard_conn = connection  # Type hint: WireGuardConnection
        if wireguard_conn.peer_id:
            print(f"Peer ID: {wireguard_conn.peer_id}")

Environment

  • SQLAlchemy: 2.x (async)
  • SQLModel/FastAPI stack
  • PostgreSQL with asyncpg

Prevention

  1. Always use await session.refresh(obj) before accessing subclass attributes
  2. Use eager loading strategies (selectinload, joinedload) when you know you'll need related data
  3. Consider using explicit type checks instead of hasattr() for polymorphic models

@PaleNeutron
Author

PaleNeutron commented Aug 22, 2025

@a0s, I think you would face the same error in pure SQLAlchemy; check this test below:

full test code
import asyncio

from sqlalchemy import Column, ForeignKey, Integer, String, select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import declarative_base, joinedload, relationship, sessionmaker

# A base for our declarative models
Base = declarative_base()


# A hypothetical Peer model
class Peer(Base):
    __tablename__ = "peers"
    id = Column(Integer, primary_key=True)
    name = Column(String)

    def __repr__(self):
        return f"Peer(id={self.id}, name='{self.name}')"


# Base polymorphic model
class BaseConnection(Base):
    __tablename__ = "connections"
    __mapper_args__ = {"polymorphic_on": "type"}

    id = Column(Integer, primary_key=True)
    type = Column(String(50))

    def __repr__(self):
        return f"BaseConnection(id={self.id})"


# Derived polymorphic model with a foreign key
class WireGuardConnection(BaseConnection):
    __mapper_args__ = {"polymorphic_identity": "wireguard"}
    peer_id = Column(ForeignKey("peers.id"))

    # This is the relationship that would cause lazy loading.
    # When you access `connection.peer`, SQLAlchemy will query the `peers` table.
    peer = relationship("Peer")

    def __repr__(self):
        return f"WireGuardConnection(id={self.id})"


async def setup_db(engine):
    """Create database tables."""
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)


async def seed_data(session: AsyncSession):
    """Insert some example data."""
    print("--- Inserting example data ---")
    peer1 = Peer(name="user_A_peer")
    connection1 = WireGuardConnection()
    connection1.peer = peer1  # SQLAlchemy will automatically set peer_id

    session.add_all([peer1, connection1])
    await session.commit()
    print("Data insertion complete.")


async def demonstrate_bug(Session: sessionmaker):
    """Demonstrates the MissingGreenlet error caused by lazy loading."""
    async with Session() as session:
        # Get the connection object
        stmt = select(BaseConnection).where(BaseConnection.id == 1)
        result = await session.execute(stmt)
        connection = result.scalar_one_or_none()

        # Attempt to access a subclass-only column that the initial query did not load
        print("\n--- Attempting to access lazy-loaded 'connection.peer_id' ---")
        print(f"Current object: {connection}")
        try:
            # ❌ ERROR: peer_id was not loaded by the base-class query, so this
            # access triggers a lazy load -- sync I/O inside an async context --
            # which raises the MissingGreenlet error
            print(f"✅ Access successful: peer id is {connection.peer_id}")
        except Exception as e:
            print(f"❌ Caught error: {type(e).__name__}: {e}")


async def demonstrate_fix(Session: sessionmaker):
    """Demonstrates how to fix the issue using eager loading."""
    async with Session() as session:
        # Select the subclass directly so that subclass-only columns such as
        # 'peer_id' are included in the initial query and no lazy load is needed later
        stmt = select(WireGuardConnection).where(WireGuardConnection.id == 1)
        result = await session.execute(stmt)
        connection = result.scalar_one_or_none()

        print("\n--- Attempting to access eagerly-loaded 'connection.peer_id' ---")
        print(f"Current object: {connection}")
        try:
            # ✅ OK: 'peer_id' was loaded by the initial query, so no new database query is needed
            print(f"✅ Access successful: peer id is {connection.peer_id}")
            print("Since the data was eagerly loaded, no new query is triggered.")
        except Exception as e:
            # No error will occur here
            print(f"❌ Caught error: {type(e).__name__}: {e}")


async def main():
    # Use an in-memory database
    engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=True)

    # Create an async session with a crucial setting: expire_on_commit=False
    Session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    await setup_db(engine)

    # Use a session to insert data
    async with Session() as session:
        await seed_data(session)

    # Run the demonstrations
    await demonstrate_bug(Session)
    await demonstrate_fix(Session)

    # Close the engine connection pool
    await engine.dispose()


if __name__ == "__main__":
    asyncio.run(main())

If you look at the SQL queries being logged to the console, you'll see exactly why this is happening.

The Lazy-Loading Issue

Initially, your code using select(BaseConnection) explicitly tells SQLAlchemy to only fetch the id and type columns. That's why the first query is:

SELECT connections.id, connections.type FROM connections

Notice that the peer_id column isn't part of this initial selection.

When your code later tries to access the connection.peer_id attribute, SQLAlchemy attempts to lazy-load this unloaded data. This triggers a second, immediate query to fetch just the missing peer_id:

SELECT connections.peer_id AS connections_peer_id FROM connections...

The problem is that this lazy-loading operation doesn't work correctly in an async context. It tries to perform I/O without the required await, which causes the MissingGreenlet error you're seeing.
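Another way out is to tell SQLAlchemy up front that subclass columns may be needed, so the first SELECT already includes them and no lazy load can ever fire. A minimal sketch using `with_polymorphic`, shown with a sync session for brevity (the same idea applies to async sessions, since no lazy I/O is ever triggered); the model names echo the test above but this is illustrative, not the PR's code:

```python
# Sketch: with_polymorphic() makes the initial SELECT include
# subclass-only columns, so accessing them later never lazy-loads.
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base, with_polymorphic

Base = declarative_base()


class Peer(Base):
    __tablename__ = "peers"
    id = Column(Integer, primary_key=True)


class BaseConnection(Base):
    __tablename__ = "connections"
    __mapper_args__ = {"polymorphic_on": "type"}
    id = Column(Integer, primary_key=True)
    type = Column(String(50))


class WireGuardConnection(BaseConnection):
    __mapper_args__ = {"polymorphic_identity": "wireguard"}
    peer_id = Column(ForeignKey("peers.id"))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Peer(id=1), WireGuardConnection(id=1, peer_id=1)])
    session.commit()

with Session(engine) as session:
    # The polymorphic entity pulls WireGuardConnection columns into the first SELECT.
    poly = with_polymorphic(BaseConnection, [WireGuardConnection])
    conn = session.execute(select(poly)).scalar_one()
    session.expunge(conn)  # detach: any remaining lazy load would now fail

# peer_id was loaded eagerly, so this works even on a detached object.
print(conn.peer_id)
```

In an async session this means `connection.peer_id` is already in memory when you touch it, so the MissingGreenlet path is never reached.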

By the way, even though this code works in sync SQLAlchemy, it still harms performance badly and should be treated as a bug: it performs one implicit SELECT query per item (the classic N+1 problem).
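The N+1 behavior is easy to observe by counting statements. A hedged sync sketch (hypothetical single-table models, in-memory SQLite) that hooks the engine's `before_cursor_execute` event to count the implicit SELECTs:

```python
# Sketch: querying the base class and then touching a subclass-only
# column on each row emits one extra SELECT per row (N+1).
from sqlalchemy import Column, Integer, String, create_engine, event, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class BaseConnection(Base):
    __tablename__ = "connections"
    __mapper_args__ = {"polymorphic_on": "type"}
    id = Column(Integer, primary_key=True)
    type = Column(String(50))


class WireGuardConnection(BaseConnection):
    __mapper_args__ = {"polymorphic_identity": "wireguard"}
    peer_id = Column(Integer)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

statements = []


@event.listens_for(engine, "before_cursor_execute")
def count_selects(conn, cursor, stmt, params, ctx, executemany):
    if stmt.lstrip().upper().startswith("SELECT"):
        statements.append(stmt)


with Session(engine) as session:
    session.add_all(WireGuardConnection(peer_id=n) for n in range(5))
    session.commit()

with Session(engine) as session:
    rows = session.execute(select(BaseConnection)).scalars().all()
    before = len(statements)
    ids = [row.peer_id for row in rows]  # one implicit SELECT per item
    lazy_selects = len(statements) - before

print(lazy_selects)
```

In sync code these extra queries succeed silently, which is exactly why the bug goes unnoticed until it shows up as load on the database.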


Recommendation for FastAPI Users

Because of subtle I/O issues like this, I don't recommend combining async database sessions with async API route functions for junior developers.

FastAPI (via Starlette) runs synchronous route functions in a separate thread pool, so a regular synchronous I/O operation, such as a standard database call, will not block the main event loop. For this reason, it's often much simpler and safer to stick with synchronous database sessions in your app (most internal web apps have lower concurrency than your worker count anyway). If you instead call a sync I/O function inside an async route, you block the event loop, which is the worst-case scenario for performance in a FastAPI application.
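The difference is visible with nothing but the standard library. In the sketch below (no FastAPI required), a blocking call made directly inside a coroutine stalls the whole event loop, while the same call pushed to a worker thread via `asyncio.to_thread` (roughly what FastAPI does for plain `def` routes) lets other tasks keep running:

```python
# Sketch: blocking the loop vs. offloading blocking work to a thread.
import asyncio
import time


def blocking_io():
    time.sleep(0.2)  # stands in for a synchronous database call


async def ticker(counter: list):
    # Counts how often the event loop got control during the "request".
    for _ in range(10):
        counter.append(1)
        await asyncio.sleep(0.02)


async def main():
    blocked, threaded = [], []

    # Bad: calling blocking_io() directly -- the ticker task never runs
    # until the blocking call returns, because the loop is frozen.
    t = asyncio.create_task(ticker(blocked))
    blocking_io()
    blocked_during = len(blocked)
    await t

    # Good: running it in a thread -- the ticker keeps making progress
    # while the blocking call executes concurrently.
    t = asyncio.create_task(ticker(threaded))
    await asyncio.to_thread(blocking_io)
    threaded_during = len(threaded)
    await t

    print(blocked_during, threaded_during)
    return blocked_during, threaded_during


blocked_n, threaded_n = asyncio.run(main())
```

The first counter stays at zero while the loop is blocked; the second keeps climbing, which is the behavior you want under concurrent requests.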
